import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import keras
import numpy as np
import math
from keras.datasets import cifar10
from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D, ZeroPadding2D, GlobalAveragePooling2D
from keras.layers import Flatten, Dense, Dropout,BatchNormalization
from keras.models import Model
from keras.layers import Input, concatenate
from keras import optimizers, regularizers
from keras.preprocessing.image import ImageDataGenerator
from keras.initializers import he_normal
from keras.callbacks import LearningRateScheduler, TensorBoard, ModelCheckpoint
import h5py
from keras.preprocessing.image import ImageDataGenerator
from keras.applications import resnet50, inception_v3, xception
from keras.applications.resnet50 import ResNet50
from keras.applications.inception_v3 import InceptionV3
from keras.applications.xception import Xception
from keras.layers.core import Lambda
from keras.layers import Input, GlobalAveragePooling2D, GlobalMaxPooling2D, Dense, Dropout, Conv2D, MaxPooling2D
from keras.models import Model
from keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from keras.optimizers import adam, Adam
from sklearn.utils import shuffle
from keras.utils import to_categorical
from keras.utils.vis_utils import plot_model
from loss_history import LossHistory
from keras.callbacks import TensorBoard
from keras import backend as K
# Drop any graph/session state left over from a previous run in this process.
K.clear_session()
from keras.applications import imagenet_utils
# NOTE(review): this alias appears unused below — the model scales pixels via
# my_preprocess/Lambda instead; confirm before removing.
preprocess = imagenet_utils.preprocess_input


BN_NORM = True  # whether to apply Batch Normalization after convolutions
CONCAT_AXIS=3  # channel axis for inception concatenation (channels_last)
WEIGHT_DECAY=1e-4  # L2 regularization strength; falsy disables L2
DATA_FORMAT='channels_last'  # Theano:'channels_first', Tensorflow:'channels_last'
EPOCHS = 20
BATCH_SIZE = 32  # batch size
IMAGE_SIZE = (512, 512)  # input image size (height, width)
LEARNING_RATE = 1e-4  # learning rate
EARLY_STOPPING_PATIENCE = 3
REDUCE_LR_PATIENCE = 3
CLASS_NUM = 2  # number of classes
DROP_RATE = 0.2


# -------------------------------------
# File paths
# -------------------------------------
TRAIN_DATA_PATH = './data/train/'
VALID_DATA_PATH = './data/valid/'
RESULT_PATH = './result/'
MODEL_PATH = './model/'
OUTPUT_PATH = './output/'
MODEL_NAME = 'inceptionv1-model.hdf5'


def my_preprocess(x):
    """Scale raw pixel values from [0, 255] down to [0, 1]."""
    return x / 255.0


def conv2D_bn2d(x, filters, kernel_size, strides=(1,1), padding='same', data_format=DATA_FORMAT,
                 dilation_rate=(1,1), activation='relu', use_bias=True, kernel_initializer='glorot_uniform',
                 bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None,
                 kernel_constraint=None, bias_constraint=None, bn_norm=BN_NORM, weight_decay=WEIGHT_DECAY):
    """Apply a Conv2D layer, optionally followed by BatchNormalization.

    Args:
        x: input Keras tensor.
        filters / kernel_size / strides / ...: forwarded to ``Conv2D``.
        bn_norm: if truthy, append a ``BatchNormalization`` layer.
        weight_decay: if truthy, fills in L2 regularizers of this strength
            for any regularizer the caller did not supply explicitly.

    Returns:
        The output Keras tensor.
    """
    # Derive L2 regularizers from `weight_decay` only where the caller did not
    # pass one explicitly.  (BUG FIX: previously any caller-supplied
    # kernel_regularizer/bias_regularizer was silently discarded — overwritten
    # in both branches.  All call sites in this file use the defaults, so
    # their behavior is unchanged.)
    if weight_decay:
        if kernel_regularizer is None:
            kernel_regularizer = regularizers.l2(weight_decay)
        if bias_regularizer is None:
            bias_regularizer = regularizers.l2(weight_decay)

    x = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, data_format=data_format,
             dilation_rate=dilation_rate, activation=activation, use_bias=use_bias, kernel_initializer=kernel_initializer,
             bias_initializer=bias_initializer, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer,
             activity_regularizer=activity_regularizer, kernel_constraint=kernel_constraint, bias_constraint=bias_constraint)(x)

    # Optional Batch Normalization after the convolution.
    if bn_norm:
        x = BatchNormalization()(x)

    return x


def inception_module(x, params, concat_axis, padding='same', data_format=DATA_FORMAT, dilation_rate=(1,1),
                     activation='relu', use_bias=True, kernel_initializer='glorot_uniform', bias_initializer='zeros',
                     kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None,
                     bias_constraint=None, bn_norm=BN_NORM, weight_decay=None):
    """GoogLeNet (Inception v1) mixed block with four parallel pathways.

    Args:
        x: input Keras tensor.
        params: tuple of four tuples of filter counts —
            (1x1,), (1x1 reduce, 3x3), (1x1 reduce, 5x5), (pool-proj 1x1).
        concat_axis: axis along which the pathways are concatenated.
        weight_decay: if truthy, fills in L2 regularizers of this strength
            for any regularizer the caller did not supply explicitly.

    Returns:
        Concatenation of the four pathway outputs.
    """
    (branch1, branch2, branch3, branch4) = params

    # Same regularizer logic as conv2D_bn2d: only fill in L2 regularizers the
    # caller did not provide.  (Previously explicit regularizers were
    # discarded; default weight_decay=None keeps prior behavior here.)
    if weight_decay:
        if kernel_regularizer is None:
            kernel_regularizer = regularizers.l2(weight_decay)
        if bias_regularizer is None:
            bias_regularizer = regularizers.l2(weight_decay)

    def _conv(tensor, n_filters, kernel_size):
        # Shared Conv2D configuration for every branch (stride 1 throughout).
        return Conv2D(filters=n_filters, kernel_size=kernel_size, strides=1, padding=padding,
                      data_format=data_format, dilation_rate=dilation_rate, activation=activation,
                      use_bias=use_bias, kernel_initializer=kernel_initializer,
                      bias_initializer=bias_initializer, kernel_regularizer=kernel_regularizer,
                      bias_regularizer=bias_regularizer, activity_regularizer=activity_regularizer,
                      kernel_constraint=kernel_constraint, bias_constraint=bias_constraint)(tensor)

    # Pathway 1: 1x1 conv
    pathway1 = _conv(x, branch1[0], (1, 1))

    # Pathway 2: 1x1 reduce -> 3x3 conv
    pathway2 = _conv(_conv(x, branch2[0], (1, 1)), branch2[1], (3, 3))

    # Pathway 3: 1x1 reduce -> 5x5 conv
    pathway3 = _conv(_conv(x, branch3[0], (1, 1)), branch3[1], (5, 5))

    # Pathway 4: 3x3 max-pool -> 1x1 projection.
    # BUG FIX: the pooling layer previously used the module-level DATA_FORMAT
    # constant instead of the `data_format` argument, unlike every other layer
    # in this function.
    pathway4 = MaxPooling2D(pool_size=(3, 3), strides=1, padding=padding, data_format=data_format)(x)
    pathway4 = _conv(pathway4, branch4[0], (1, 1))

    return concatenate([pathway1, pathway2, pathway3, pathway4], axis=concat_axis)


def create_model(img_input):
    """Build the Inception-v1 (GoogLeNet) body on top of `img_input`.

    Returns a 2-way softmax output tensor (global-average-pooled head).
    """
    # --- Stem: 7x7/2 conv -> max-pool -> BN -> 1x1 conv -> 3x3 conv+BN -> max-pool
    x = conv2D_bn2d(img_input, 64, (7,7), 2, padding='same', bn_norm=False)
    x = MaxPooling2D(pool_size=(3,3), strides=2, padding='same', data_format=DATA_FORMAT)(x)
    x = BatchNormalization()(x)
    x = conv2D_bn2d(x, 64, (1,1), 1, padding='same', bn_norm=False)
    x = conv2D_bn2d(x, 192, (3,3), 1, padding='same', bn_norm=True)
    x = MaxPooling2D(pool_size=(3,3), strides=2, padding='same', data_format=DATA_FORMAT)(x)

    # --- Inception stages. Each row is the per-branch filter spec
    # (1x1, 1x1->3x3, 1x1->5x5, pool->1x1) for one mixed block.
    stage3 = [
        [(64,), (96, 128), (16, 32), (32,)],      # 3a
        [(128,), (128, 192), (32, 96), (64,)],    # 3b
    ]
    stage4 = [
        [(192,), (96, 208), (16, 48), (64,)],     # 4a
        [(160,), (112, 224), (24, 64), (64,)],    # 4b
        [(128,), (128, 256), (24, 64), (64,)],    # 4c
        [(112,), (144, 288), (32, 64), (64,)],    # 4d
        [(256,), (160, 320), (32, 128), (128,)],  # 4e
    ]
    stage5 = [
        [(256,), (160, 320), (32, 128), (128,)],  # 5a
        [(384,), (192, 384), (48, 128), (128,)],  # 5b
    ]

    # A 3x3/2 max-pool separates stage 3->4 and 4->5; no pooling after stage 5.
    for stage in (stage3, stage4, stage5):
        for block_params in stage:
            x = inception_module(x, params=block_params, concat_axis=CONCAT_AXIS)
        if stage is not stage5:
            x = MaxPooling2D(pool_size=(3,3), strides=2, padding='same', data_format=DATA_FORMAT)(x)

    # --- Classification head
    x = GlobalAveragePooling2D()(x)
    x = Dense(2, activation='softmax', kernel_initializer="he_normal",
              kernel_regularizer=regularizers.l2(WEIGHT_DECAY))(x)
    return x


# Plain generators: no augmentation or rescaling is configured here — pixel
# scaling is done inside the model by the Lambda(my_preprocess) layer.
train_gen = ImageDataGenerator()
valid_gen = ImageDataGenerator()
# Classes are inferred from sub-directory names under each data path;
# class_mode is left at its Keras default (presumably 'categorical', i.e.
# one-hot labels — confirm against the loss used below).
# color_mode='grayscale' yields single-channel images matching Input(..., 1).
train_generator = train_gen.flow_from_directory(TRAIN_DATA_PATH, IMAGE_SIZE, shuffle=True, batch_size=BATCH_SIZE, color_mode='grayscale')
valid_generator = valid_gen.flow_from_directory(VALID_DATA_PATH, IMAGE_SIZE, batch_size=BATCH_SIZE, color_mode='grayscale')


# Build the model: grayscale input -> in-graph [0, 1] scaling -> Inception v1.
inputs = Input((*IMAGE_SIZE, 1))
x = Lambda(my_preprocess)(inputs)

# BUG FIX: create_model was previously called with the raw `inputs` tensor,
# which silently bypassed the Lambda preprocessing layer above (images entered
# the network unscaled and `x` was dead).
output = create_model(x)
model = Model(inputs, output)
model.summary()

plot_model(model, to_file=os.path.join(RESULT_PATH, 'inceptionv1.png'), show_shapes=True)


# Checkpoint: persist the full model to MODEL_PATH whenever val_loss improves.
check_point = ModelCheckpoint(
    filepath=os.path.join(MODEL_PATH, MODEL_NAME),
    monitor='val_loss',
    mode='auto',
    verbose=1,
    save_best_only=True,
    save_weights_only=False)

# Stop training once val_loss has not improved for EARLY_STOPPING_PATIENCE epochs.
early_stopping = EarlyStopping(
    monitor='val_loss', mode='auto', verbose=0,
    patience=EARLY_STOPPING_PATIENCE)

# Shrink the learning rate by 10x when val_loss plateaus.
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss', mode='auto', verbose=0,
    factor=0.1, patience=REDUCE_LR_PATIENCE,
    epsilon=0.0001, cooldown=0, min_lr=0)

# Records per-batch and per-epoch loss/accuracy for the plots at the end.
history = LossHistory()


# Compile with the Adam class (the lowercase `adam` alias is deprecated in
# newer Keras releases; `Adam` is already imported above).
# NOTE(review): with a 2-unit softmax output, 'categorical_crossentropy' would
# be the conventional loss; 'binary_crossentropy' is kept to preserve the
# existing training objective — confirm before changing.
model.compile(optimizer=Adam(lr=LEARNING_RATE), loss='binary_crossentropy', metrics=['accuracy'])


# Train from the directory generators.
model.fit_generator(
    train_generator,
    steps_per_epoch=train_generator.samples // BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=valid_generator,
    validation_steps=valid_generator.samples // BATCH_SIZE,
    # BUG FIX: reduce_lr was constructed above but never registered here,
    # so the learning-rate schedule never ran.
    callbacks=[check_point, early_stopping, reduce_lr, history]
)


# Plot the loss and accuracy curves at batch and epoch granularity.
history.loss_plot('batch', os.path.join(RESULT_PATH, 'inception_v1_loss_batch.png'))
history.acc_plot('batch', os.path.join(RESULT_PATH, 'inception_v1_acc_batch.png'))
history.loss_plot('epoch', os.path.join(RESULT_PATH, 'inception_v1_loss_epoch.png'))
history.acc_plot('epoch', os.path.join(RESULT_PATH, 'inception_v1_acc_epoch.png'))